{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Using Polymetis with Habitat\n",
    "\n",
    "Follow the official instructions [here](https://facebookresearch.github.io/fairo/polymetis/installation.html) to install Polymetis locally. Because the Polymetis server and client are launched separately, first launch the server locally using the instructions below, then run the remaining cells in this notebook to define the client-side controller.\n",
    "\n",
    "If Jupyter was not launched from the Polymetis environment, register that environment as a kernel:\n",
    "1. Activate the polymetis environment\n",
    "2. Run the following command: `ipython kernel install --user --name=POLYMETIS_ENV_NAME`\n",
    "3. Reload this page\n",
    "4. Open Kernel > Change Kernel... and select the polymetis environment name"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The client configuration for habitat is included in polymetis. To use it, specify `robot_client=habitat_sim` and set the `habitat_scene_path` as an **absolute path** to your scene file."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "```\n",
    "launch_robot.py robot_client=habitat_sim habitat_scene_path=/PATH/TO/scene use_real_time=false gui=true\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To change the robot model, specify the `robot_model`:"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "```\n",
    "launch_robot.py robot_client=habitat_sim robot_model=ROBOT_MODEL habitat_scene_path=/PATH/TO/scene use_real_time=false gui=true\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Defining a controller"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "lines_to_next_cell": 1
   },
   "outputs": [],
   "source": [
    "from typing import Dict\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "import torchcontrol as toco\n",
    "from polymetis import RobotInterface"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class MySinePolicy(toco.PolicyModule):\n",
    "    \"\"\"\n",
    "    Custom policy that tracks a sine trajectory on joint 6 (index 5) with a\n",
    "    joint-space PD controller. All other joints are held at the positions\n",
    "    read on the first control step.\n",
    "\n",
    "    The offset added to joint 6 is\n",
    "    ``magnitude * sin(pi * steps / (hz * period))``, so one full oscillation\n",
    "    spans ``2 * hz * period`` control steps.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, time_horizon, hz, magnitude, period, kq, kqd, **kwargs):\n",
    "        \"\"\"\n",
    "        Args:\n",
    "            time_horizon (int):         Number of steps policy should execute\n",
    "            hz (double):                Frequency of controller\n",
    "            magnitude (double):         Amplitude of the sine offset on joint 6\n",
    "            period (double):            Time scale of the sine wave (see class\n",
    "                                        docstring for the exact formula)\n",
    "            kq, kqd (torch.Tensor):     PD gains (1d array)\n",
    "            **kwargs:                   Forwarded to toco.PolicyModule\n",
    "        \"\"\"\n",
    "        super().__init__(**kwargs)\n",
    "\n",
    "        self.hz = hz\n",
    "        self.time_horizon = time_horizon\n",
    "        self.m = magnitude\n",
    "        self.T = period\n",
    "\n",
    "        # Initialize modules\n",
    "        self.feedback = toco.modules.JointSpacePD(kq, kqd)\n",
    "\n",
    "        # Initialize variables\n",
    "        self.steps = 0\n",
    "        # Placeholder with the same shape as the gains; overwritten with the\n",
    "        # measured joint positions on the first call to forward()\n",
    "        self.q_initial = torch.zeros_like(kq)\n",
    "\n",
    "    def forward(self, state_dict: Dict[str, torch.Tensor]):\n",
    "        \"\"\"\n",
    "        Compute the PD torque command for one control step.\n",
    "\n",
    "        Args:\n",
    "            state_dict: Robot state; the entries \"joint_positions\" and\n",
    "                \"joint_velocities\" are read.\n",
    "\n",
    "        Returns:\n",
    "            Dict with the single key \"joint_torques\" holding the PD output.\n",
    "        \"\"\"\n",
    "        # Parse states\n",
    "        q_current = state_dict[\"joint_positions\"]\n",
    "        qd_current = state_dict[\"joint_velocities\"]\n",
    "\n",
    "        # Latch the starting joint positions on the first step\n",
    "        if self.steps == 0:\n",
    "            self.q_initial = q_current.clone()\n",
    "\n",
    "        # Compute reference position and velocity\n",
    "        q_desired = self.q_initial.clone()\n",
    "        q_desired[5] = self.q_initial[5] + self.m * torch.sin(\n",
    "            np.pi * self.steps / (self.hz * self.T)\n",
    "        )\n",
    "        # The reference velocity is kept at zero; the derivative of the sine\n",
    "        # reference is not fed to the PD controller\n",
    "        qd_desired = torch.zeros_like(qd_current)\n",
    "\n",
    "        # Execute PD control\n",
    "        output = self.feedback(q_current, qd_current, q_desired, qd_desired)\n",
    "\n",
    "        # Check termination: flag the policy as finished once the step counter\n",
    "        # has passed time_horizon\n",
    "        if self.steps > self.time_horizon:\n",
    "            self.set_terminated()\n",
    "        self.steps += 1\n",
    "\n",
    "        return {\"joint_torques\": output}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Connect to the robot server on this machine\n",
    "robot = RobotInterface(ip_address=\"localhost\")\n",
    "\n",
    "# Reset the robot before running the policy\n",
    "robot.go_home()\n",
    "\n",
    "# Read the controller frequency and default PD gains from the robot metadata\n",
    "hz = robot.metadata.hz\n",
    "default_kq = torch.Tensor(robot.metadata.default_Kq)\n",
    "default_kqd = torch.Tensor(robot.metadata.default_Kqd)\n",
    "\n",
    "# Build the sine policy with a horizon of 5 * hz control steps\n",
    "policy = MySinePolicy(\n",
    "    time_horizon=5 * hz, hz=hz,\n",
    "    magnitude=0.5, period=2.0,\n",
    "    kq=default_kq, kqd=default_kqd,\n",
    ")\n",
    "\n",
    "# Send the policy to the robot and keep whatever send_torch_policy returns\n",
    "print(\"\\nRunning custom sine policy ...\\n\")\n",
    "state_log = robot.send_torch_policy(policy)"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "jupytext": {
   "cell_metadata_filter": "-all",
   "formats": "nb_python//py:percent,colabs//ipynb",
   "notebook_metadata_filter": "all"
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
